import paddle
import paddlefsl
from paddlefsl.model_zoo import maml

# Select the first GPU as the computing device.
paddle.set_device('gpu:0')


# Config: MAML, Mini-ImageNet, Conv, 5 Ways, 1 Shot

# Few-shot task definition.
WAYS = 5
SHOTS = 1

# Mini-ImageNet splits, one dataset object per mode.
TRAIN_DATASET = paddlefsl.datasets.MiniImageNet(mode='train')
VALID_DATASET = paddlefsl.datasets.MiniImageNet(mode='valid')
TEST_DATASET = paddlefsl.datasets.MiniImageNet(mode='test')

# Conv backbone with one output per way and four 32-channel entries in
# conv_channels.
MODEL = paddlefsl.backbones.Conv(
    input_size=(3, 84, 84),
    output_size=WAYS,
    conv_channels=[32] * 4,
)

# Learning rates.
META_LR = 0.002
INNER_LR = 0.03

# Training settings.
ITERATIONS = 30000
META_BATCH_SIZE = 32  # main() also passes this as test_batch_size
TRAIN_INNER_ADAPT_STEPS = 5
APPROXIMATE = True

# Testing settings.
TEST_EPOCH = 10
TEST_INNER_ADAPT_STEPS = 2

# Reporting and checkpointing.
REPORT_ITER = 1
SAVE_MODEL_ITER = 20
SAVE_MODEL_ROOT = r'D:\cp\fsl_train_model\maml'
# Parameter file name derived from the final iteration number.
TEST_PARAM_FILE = 'iteration{}.params'.format(ITERATIONS)


def main():
    """Load a saved checkpoint into MODEL and run MAML meta-testing.

    The checkpoint location is hard-coded below. Meta-training is currently
    disabled; the commented-out call records how it is invoked with the
    module-level configuration.
    """
    # Disabled meta-training step. When enabled, the value returned by
    # maml.meta_training is used as the directory from which the parameter
    # file TEST_PARAM_FILE is loaded.
    #
    # train_dir = maml.meta_training(train_dataset=TRAIN_DATASET,
    #                                valid_dataset=VALID_DATASET,
    #                                ways=WAYS,
    #                                shots=SHOTS,
    #                                model=MODEL,
    #                                meta_lr=META_LR,
    #                                inner_lr=INNER_LR,
    #                                iterations=ITERATIONS,
    #                                meta_batch_size=META_BATCH_SIZE,
    #                                inner_adapt_steps=TRAIN_INNER_ADAPT_STEPS,
    #                                approximate=APPROXIMATE,
    #                                report_iter=REPORT_ITER,
    #                                save_model_iter=SAVE_MODEL_ITER,
    #                                save_model_root=SAVE_MODEL_ROOT)
    #
    # print(train_dir)
    # state_dict = paddle.load(train_dir + '/' + TEST_PARAM_FILE)

    # Validation: evaluate a previously saved checkpoint.
    checkpoint_path = (
        r"D:\cp\fsl_train_model\maml\maml_miniimagenet_conv_5ways_1shots"
        r"\metalr0.002_innerlr0.03_batchsize32_adaptsteps5_approximate"
        r"\iteration10840.params"
    )
    MODEL.load_dict(paddle.load(checkpoint_path))

    # Gather the meta-testing arguments in one place before the call.
    testing_options = {
        'model': MODEL,
        'test_dataset': TEST_DATASET,
        'test_epoch': TEST_EPOCH,
        'test_batch_size': META_BATCH_SIZE,
        'ways': WAYS,
        'shots': SHOTS,
        'inner_lr': INNER_LR,
        'inner_adapt_steps': TEST_INNER_ADAPT_STEPS,
        'approximate': APPROXIMATE,
    }
    maml.meta_testing(**testing_options)


# Run main() only when this file is executed as a script, not when imported.
if __name__ == '__main__':
    main()
